import numpy as np


class simulannealbnd:
    """Bound-constrained simulated-annealing minimizer.

    Each temperature level runs ``K`` trial moves. A trial perturbs the
    current point by a temperature-scaled random step (at most 0.01 per
    component). It always accepts improvements, and accepts worse points with
    the Metropolis probability ``exp(-(y_new - y) / T)``. After each level the
    temperature is lowered to ``T = T_max / (1 + t)``; the search stops once
    ``T <= T_min``.

    Parameters
    ----------
    T_max : float
        Starting temperature.
    T_min : float
        Stopping temperature; must satisfy ``T_max > T_min > 0``.
    K : int
        Number of trial moves per temperature level.

    Example
    -------
    >>> import numpy as np
    >>> def func(x):
    ...     return -np.sin(10 * np.pi * x[0]) / x[1]
    >>> s = simulannealbnd(T_max=100, T_min=1)
    >>> x_best, y_best = s.fit(func, np.array([1, 1]), [1, 1], [2, 2])
    """

    def __init__(self, T_max=300, T_min=1, K=100):
        assert T_max > T_min > 0

        self.T_max = T_max
        self.T_min = T_min
        self.K = K
        self.t = 0  # number of completed temperature levels
        self.T = self.T_max  # current temperature

    def fit(self, f, x0, lb, ub):
        """Minimize ``f`` over the box ``[lb, ub]`` starting from ``x0``.

        Parameters
        ----------
        f : callable
            Objective ``f(x) -> float`` to minimize.
        x0 : array_like
            Starting point; its length sets the problem dimension.
        lb, ub : array_like
            Lower and upper bounds, broadcastable against ``x0``.

        Returns
        -------
        tuple
            ``(x_best, y_best)``. These are the best point found and its
            objective value. Apart from ``x0`` itself, only points strictly
            inside the bounds are recorded as best.
        """
        # Reset the schedule so every call to fit() performs a full anneal.
        self.t = 0
        self.T = self.T_max

        self.f = f
        self.lb = np.asarray(lb, dtype=float)
        self.ub = np.asarray(ub, dtype=float)
        self.x_best = np.asarray(x0, dtype=float)
        self.dim = len(self.x_best)
        self.y_best = self.f(self.x_best)
        self.y_best_history = [self.y_best]
        self.x_best_history = [self.x_best]

        x_cur, y_cur = self.x_best, self.y_best
        while self.T > self.T_min:
            for _ in range(self.K):
                x_new = self.__get_new_x(x_cur)
                y_new = self.f(x_new)
                delta = y_new - y_cur
                # Metropolis criterion: downhill moves are always taken,
                # uphill moves with probability exp(-delta / T).
                if delta < 0 or np.exp(-delta / self.T) > np.random.rand():
                    x_cur, y_cur = x_new, y_new
                    if (y_cur < self.y_best
                            and np.all(self.lb < x_cur)
                            and np.all(x_cur < self.ub)):
                        self.y_best = y_cur
                        self.x_best = x_cur
                        self.y_best_history.append(y_cur)
                        self.x_best_history.append(x_cur)
            # Advance the level counter before cooling. Otherwise the first
            # cooling step computes T_max / (1 + 0) and repeats T_max.
            self.t += 1
            self.__cool_down()

        # The history only grows on strict improvement, so the tracked best
        # is also the minimum of the history.
        return self.x_best, self.y_best

    def __cool_down(self):
        """Lower the temperature following ``T = T_max / (1 + t)``."""
        self.T = self.T_max / (1 + self.t)

    def __get_new_x(self, x):
        """Return a random neighbour of ``x`` that lies inside ``[lb, ub]``."""
        r = np.random.uniform(-0.01, 0.01, size=self.dim)
        xc = np.sign(r) * self.T * ((1 + 1.0 / self.T) ** np.abs(r) - 1.0)
        x_new = x + xc
        # If the step leaves the box on either side, take it in the opposite
        # direction instead.
        if np.any(x_new < self.lb) or np.any(x_new > self.ub):
            x_new = x_new - 2 * xc
        # The reflected point can still be outside (e.g. x sits on a bound and
        # the step straddles it), so clamp to guarantee feasibility.
        return np.clip(x_new, self.lb, self.ub)



class simulanneal_TSP(simulannealbnd):
    """Simulated annealing for travelling-salesman style routing problems.

    A route is a sequence of city indices. Neighbours are produced by
    reversing a random contiguous segment of the route (a 2-opt move).
    Acceptance uses the Metropolis rule ``exp(-(y_new - y) / T)``. Cooling
    follows ``T = T_max / (1 + t)`` after every ``K`` trials until
    ``T <= T_min``.
    """

    def __init__(self, T_max=300, T_min=1, K=100):
        # Forward the caller's settings; the base __init__ validates them.
        super().__init__(T_max=T_max, T_min=T_min, K=K)

    def fit(self, f, dist, x0):
        """Search for a low-cost route.

        Parameters
        ----------
        f : callable
            Cost function ``f(route, dist) -> float`` to minimize.
        dist : object
            Passed unchanged to ``f``; presumably a distance matrix (confirm
            against callers).
        x0 : array_like
            Initial route, a sequence of city indices.

        Returns
        -------
        tuple
            ``(x_best, y_best)``: the best route found and its cost.
        """
        # Reset the schedule so every call to fit() performs a full anneal.
        self.t = 0
        self.T = self.T_max

        self.f = f
        self.x_best = np.asarray(x0)
        self.dim = len(self.x_best)
        self.y_best = self.f(self.x_best, dist)
        self.y_best_history = [self.y_best]
        self.x_best_history = [self.x_best]

        x_cur, y_cur = self.x_best, self.y_best
        while self.T > self.T_min:
            for _ in range(self.K):
                x_new = self.__get_new_x(x_cur)
                y_new = self.f(x_new, dist)
                delta = y_new - y_cur
                # Metropolis criterion: downhill moves are always taken,
                # uphill moves with probability exp(-delta / T).
                if delta < 0 or np.exp(-delta / self.T) > np.random.rand():
                    x_cur, y_cur = x_new, y_new
                    if y_cur < self.y_best:
                        self.y_best = y_cur
                        self.x_best = x_cur
                        self.y_best_history.append(y_cur)
                        self.x_best_history.append(x_cur)
            # Advance the level counter before cooling. Otherwise the first
            # cooling step computes T_max / (1 + 0) and repeats T_max.
            self.t += 1
            self.__cool_down()

        # The history only grows on strict improvement, so the tracked best
        # is also the minimum of the history.
        return self.x_best, self.y_best

    def __cool_down(self):
        """Lower the temperature following ``T = T_max / (1 + t)``.

        This is defined here because name mangling turns ``self.__cool_down``
        inside this class into ``_simulanneal_TSP__cool_down``. The base
        class's private method is therefore not reachable under that name.
        """
        self.T = self.T_max / (1 + self.t)

    def __get_new_x(self, x):
        """Return a copy of route ``x`` with a random segment reversed."""
        n = len(x)
        x_new = np.copy(x)
        if n < 2:
            # No segment to reverse.
            return x_new
        # Sort the two distinct positions so the slice is never empty.
        u, v = np.sort(np.random.choice(n, size=2, replace=False))
        # Inclusive upper end, so every move changes the route.
        x_new[u:v + 1] = x_new[u:v + 1][::-1]
        return x_new